import random
import torch
from einops import rearrange

from comfy.ldm.modules.attention import optimized_attention

from .rave_utils import grid_to_list, list_to_grid, shuffle_indices, shuffle_tensors2


def padding_count(n_frames, grid_frame_count):
    """Return how many extra frames are needed so that ``n_frames`` is an
    exact multiple of ``grid_frame_count`` (0 when it already divides evenly).
    """
    # Python's modulo on a negated operand yields exactly the shortfall:
    # 0 when n_frames divides evenly, otherwise grid_frame_count - remainder.
    return (-n_frames) % grid_frame_count


def get_rave_attention(index_bank=None, grid_size=2, seed=None, is_first_attention=False):
    """Build a RAVE-style attention function for patching a diffusion model.

    RAVE shuffles video frames and tiles them into grid_size x grid_size
    "grid frames" before self-attention, so attention mixes tokens across
    frames; the output is un-gridded and un-shuffled back to frame order.

    Args:
        index_bank: optional dict shared across patched layers; its
            'target_indexes' entry caches one shuffle permutation so all
            layers of a step use the same ordering.
        grid_size: side length of the frame grid (grid_size**2 frames per grid).
        seed: seed for both the padding-frame choice and the shuffle.
        is_first_attention: when True, this layer regenerates the banked
            permutation instead of reusing it.

    Returns:
        A callable (q, k, v, extra_options) -> attention output tensor,
        matching ComfyUI's attention-patch signature.
    """
    def rave_attention(q, k, v, extra_options):
        # q/k/v are (batch, tokens, channels); batch is the per-cond frame
        # count times the number of cond/uncond passes.
        batch_size, sequence_length, dim = q.shape
        n_heads = extra_options['n_heads']
        len_conds = len(extra_options['cond_or_uncond'])
        n_frames = batch_size // len_conds
        original_n_frames = n_frames

        # Pad each cond's frame stack with randomly repeated frames so the
        # count divides evenly into grid_size**2-frame grids. Seeding here
        # keeps the padding choice deterministic per seed.
        grid_frame_count = grid_size * grid_size
        n_padding_frames = padding_count(n_frames, grid_frame_count)
        if n_padding_frames > 0:
            random.seed(seed)
            cond_qs = []
            cond_ks = []
            cond_vs = []
            padding_frames = [random.randint(
                0, n_frames-1) for _ in range(n_padding_frames)]
            for cond_idx in range(len_conds):
                # Slice out this cond's frames and append the same duplicated
                # padding frames for q, k and v alike.
                start, end = cond_idx*n_frames, (cond_idx+1)*n_frames
                cond_q = q[start:end]
                cond_q = torch.cat([cond_q, cond_q[padding_frames]])
                cond_qs.append(cond_q)
                cond_k = k[start:end]
                cond_k = torch.cat([cond_k, cond_k[padding_frames]])
                cond_ks.append(cond_k)
                cond_v = v[start:end]
                cond_v = torch.cat([cond_v, cond_v[padding_frames]])
                cond_vs.append(cond_v)

            q = torch.cat(cond_qs)
            k = torch.cat(cond_ks)
            v = torch.cat(cond_vs)

        # From here on, n_frames includes the padding frames.
        n_frames = n_frames + n_padding_frames

        # Recover the latent h,w of each frame from the token count and the
        # original image aspect ratio. NOTE(review): int() truncation assumes
        # h*w == sequence_length for the inferred w — TODO confirm this holds
        # for all resolutions this patch is used with.
        shape = extra_options['original_shape']
        oh, ow = shape[-2:]
        ratio = oh/ow
        d = sequence_length
        w = int((d/ratio)**(0.5))
        h = int(d/w)

        # Unflatten tokens back to spatial maps so frames can be tiled.
        q = rearrange(q, 'b (h w) c -> b h w c', h=h, w=w)
        k = rearrange(k, 'b (h w) c -> b h w c', h=h, w=w)
        v = rearrange(v, 'b (h w) c -> b h w c', h=h, w=w)

        # Pick the frame-shuffle permutation. When an index_bank is supplied,
        # reuse its cached permutation unless absent or this is the first
        # attention layer (which refreshes it for the whole bank).
        target_indexes = None
        if index_bank is not None:
            target_indexes = index_bank['target_indexes']
            if target_indexes is None or is_first_attention:
                target_indexes = shuffle_indices(n_frames, seed=seed)
            index_bank['target_indexes'] = target_indexes
        # No bank: derive a (seed-deterministic) permutation locally.
        if target_indexes is None:
            target_indexes = shuffle_indices(n_frames, seed=seed)

        original_indexes = list(range(n_frames))
        qs = []
        ks = []
        vs = []

        # Per cond: shuffle frames in place, then tile every grid_size**2
        # consecutive frames into one large grid frame.
        for i in range(len_conds):
            start, end = i*n_frames, (i+1)*n_frames
            q[start:end] = shuffle_tensors2(
                q[start:end], original_indexes, target_indexes)
            qs.append(list_to_grid(q[start:end], grid_size))
            k[start:end] = shuffle_tensors2(
                k[start:end], original_indexes, target_indexes)
            ks.append(list_to_grid(k[start:end], grid_size))
            v[start:end] = shuffle_tensors2(
                v[start:end], original_indexes, target_indexes)
            vs.append(list_to_grid(v[start:end], grid_size))

        q = torch.cat(qs)
        k = torch.cat(ks)
        v = torch.cat(vs)

        # Flatten the grid frames back to token sequences for attention.
        q = rearrange(q, 'b h w c -> b (h w) c')
        k = rearrange(k, 'b h w c -> b (h w) c')
        v = rearrange(v, 'b h w c -> b (h w) c')

        out = optimized_attention(q, k, v, n_heads, None)

        # Invert the tiling: grid frames back to individual frames.
        gh, gw = grid_size*h, grid_size*w
        out = rearrange(out, 'b (h w) c -> b h w c', h=gh, w=gw)
        out = grid_to_list(out, grid_size)
        out = rearrange(out, 'b h w c -> b (h w) c')

        # Invert the shuffle per cond and drop the padding frames so the
        # output batch matches the caller's original frame count.
        outs = []
        for i in range(len_conds):
            start, end = i*n_frames, (i+1)*n_frames
            cond_out = shuffle_tensors2(
                out[start:end], target_indexes, original_indexes)
            cond_out = cond_out[:original_n_frames]
            outs.append(cond_out)

        return torch.cat(outs)

    return rave_attention
